{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<center>\n",
    "    <p style=\"text-align:center\">\n",
    "        <img alt=\"phoenix logo\" src=\"https://storage.googleapis.com/arize-assets/phoenix/assets/phoenix-logo-light.svg\" width=\"200\"/>\n",
    "        <br>\n",
    "        <a href=\"https://arize.com/docs/phoenix/\">Docs</a>\n",
    "        |\n",
    "        <a href=\"https://github.com/Arize-ai/phoenix\">GitHub</a>\n",
    "        |\n",
    "        <a href=\"https://arize-ai.slack.com/join/shared_invite/zt-2w57bhem8-hq24MB6u7yE_ZF_ilOYSBw#/shared-invite/email\">Community</a>\n",
    "    </p>\n",
    "</center>\n",
    "<h1 align=\"center\">Evaluating an Agent</h1>\n",
    "\n",
    "This notebook is an end-to-end example of how to trace and evaluate an agent. It uses a \"talk-to-your-data\" agent that answers questions about a store sales dataset by querying, analyzing, and visualizing the data.\n",
    "\n",
    "The notebook includes:\n",
    "* Manually instrumenting an agent using Phoenix decorators\n",
    "* Evaluating function calling accuracy using LLM as a Judge\n",
    "* Evaluating function calling accuracy by comparing to ground truth\n",
    "* Evaluating SQL query generation\n",
    "* Evaluating Python code generation\n",
    "* Evaluating the path of an agent\n",
    "\n",
    "## Install Dependencies, Import Libraries, Set API Keys"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Note: you may need to restart the kernel to use updated packages.\n"
     ]
    }
   ],
   "source": [
    "%pip install -q openai arize-phoenix openinference-instrumentation-openai python-dotenv duckdb openinference-instrumentation matplotlib seaborn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "import dotenv\n",
    "\n",
    "dotenv.load_dotenv()\n",
    "\n",
    "import json\n",
    "import os\n",
    "from getpass import getpass\n",
    "\n",
    "import duckdb\n",
    "import pandas as pd\n",
    "from IPython.display import Markdown\n",
    "from openai import OpenAI\n",
    "from openinference.instrumentation import (\n",
    "    suppress_tracing,\n",
    ")\n",
    "from openinference.instrumentation.openai import OpenAIInstrumentor\n",
    "from opentelemetry.trace import StatusCode\n",
    "from pydantic import BaseModel, Field\n",
    "from tqdm import tqdm\n",
    "\n",
    "from phoenix.otel import register"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "if not (openai_api_key := os.getenv(\"OPENAI_API_KEY\")):\n",
    "    openai_api_key = getpass(\"🔑 Enter your OpenAI API key: \")\n",
    "os.environ[\"OPENAI_API_KEY\"] = openai_api_key\n",
    "\n",
    "client = OpenAI()\n",
    "model = \"gpt-4o-mini\"\n",
    "project_name = \"talk-to-your-data-agent\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Enable Phoenix Tracing"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Sign up for a free instance of [Phoenix Cloud](https://app.phoenix.arize.com) to get your API key. If you'd prefer, you can instead [self-host Phoenix](https://arize.com/docs/phoenix/deployment)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "if not (phoenix_endpoint := os.getenv(\"PHOENIX_COLLECTOR_ENDPOINT\")):\n",
    "    phoenix_endpoint = getpass(\"🔑 Enter your Phoenix Collector Endpoint: \")\n",
    "os.environ[\"PHOENIX_COLLECTOR_ENDPOINT\"] = phoenix_endpoint\n",
    "\n",
    "\n",
    "if not (phoenix_api_key := os.getenv(\"PHOENIX_API_KEY\")):\n",
    "    phoenix_api_key = getpass(\"🔑 Enter your Phoenix API key: \")\n",
    "os.environ[\"PHOENIX_API_KEY\"] = phoenix_api_key"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Library/Frameworks/Python.framework/Versions/3.13/lib/python3.13/site-packages/phoenix/otel/otel.py:333: UserWarning: Could not infer collector endpoint protocol, defaulting to HTTP.\n",
      "  warnings.warn(\"Could not infer collector endpoint protocol, defaulting to HTTP.\")\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "🔭 OpenTelemetry Tracing Details 🔭\n",
      "|  Phoenix Project: talk-to-your-data-agent\n",
      "|  Span Processor: SimpleSpanProcessor\n",
      "|  Collector Endpoint: https://app.phoenix.arize.com/s/schavali/v1/traces\n",
      "|  Transport: HTTP + protobuf\n",
      "|  Transport Headers: {'authorization': '****'}\n",
      "|  \n",
      "|  Using a default SpanProcessor. `add_span_processor` will overwrite this default.\n",
      "|  \n",
      "|  ⚠️ WARNING: It is strongly advised to use a BatchSpanProcessor in production environments.\n",
      "|  \n",
      "|  `register` has set this TracerProvider as the global OpenTelemetry default.\n",
      "|  To disable this behavior, call `register` with `set_global_tracer_provider=False`.\n",
      "\n"
     ]
    }
   ],
   "source": [
    "tracer_provider = register(project_name=project_name, auto_instrument=True)\n",
    "\n",
    "tracer = tracer_provider.get_tracer(__name__)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prepare dataset\n",
    "\n",
    "Your agent will query this data with SQL through an in-memory DuckDB table. Start by loading the Store Sales Price Elasticity Promotions dataset into a pandas DataFrame:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Store_Number</th>\n",
       "      <th>SKU_Coded</th>\n",
       "      <th>Product_Class_Code</th>\n",
       "      <th>Sold_Date</th>\n",
       "      <th>Qty_Sold</th>\n",
       "      <th>Total_Sale_Value</th>\n",
       "      <th>On_Promo</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1320</td>\n",
       "      <td>6172800</td>\n",
       "      <td>22875</td>\n",
       "      <td>2021-11-02</td>\n",
       "      <td>3</td>\n",
       "      <td>56.849998</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>2310</td>\n",
       "      <td>6172800</td>\n",
       "      <td>22875</td>\n",
       "      <td>2021-11-03</td>\n",
       "      <td>1</td>\n",
       "      <td>18.950001</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>3080</td>\n",
       "      <td>6172800</td>\n",
       "      <td>22875</td>\n",
       "      <td>2021-11-03</td>\n",
       "      <td>1</td>\n",
       "      <td>18.950001</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>2310</td>\n",
       "      <td>6172800</td>\n",
       "      <td>22875</td>\n",
       "      <td>2021-11-06</td>\n",
       "      <td>1</td>\n",
       "      <td>18.950001</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>4840</td>\n",
       "      <td>6172800</td>\n",
       "      <td>22875</td>\n",
       "      <td>2021-11-07</td>\n",
       "      <td>1</td>\n",
       "      <td>18.950001</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   Store_Number  SKU_Coded  Product_Class_Code   Sold_Date  Qty_Sold  \\\n",
       "0          1320    6172800               22875  2021-11-02         3   \n",
       "1          2310    6172800               22875  2021-11-03         1   \n",
       "2          3080    6172800               22875  2021-11-03         1   \n",
       "3          2310    6172800               22875  2021-11-06         1   \n",
       "4          4840    6172800               22875  2021-11-07         1   \n",
       "\n",
       "   Total_Sale_Value  On_Promo  \n",
       "0         56.849998         0  \n",
       "1         18.950001         0  \n",
       "2         18.950001         0  \n",
       "3         18.950001         0  \n",
       "4         18.950001         0  "
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "store_sales_df = pd.read_parquet(\n",
    "    \"https://storage.googleapis.com/arize-phoenix-assets/datasets/unstructured/llm/llama-index/Store_Sales_Price_Elasticity_Promotions_Data.parquet\"\n",
    ")\n",
    "store_sales_df.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define the tools\n",
    "\n",
    "Next, define the tools your agent can call: a database lookup, a data visualization generator, and a data analyzer."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Tool 1: Database Lookup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "SQL_GENERATION_PROMPT = \"\"\"\n",
    "Generate an SQL query based on a prompt. Do not reply with anything besides the SQL query.\n",
    "The prompt is: {prompt}\n",
    "\n",
    "The available columns are: {columns}\n",
    "The table name is: {table_name}\n",
    "\"\"\"\n",
    "\n",
    "\n",
    "def generate_sql_query(prompt: str, columns: list, table_name: str) -> str:\n",
    "    \"\"\"Generate an SQL query based on a prompt\"\"\"\n",
    "    formatted_prompt = SQL_GENERATION_PROMPT.format(\n",
    "        prompt=prompt, columns=columns, table_name=table_name\n",
    "    )\n",
    "\n",
    "    response = client.chat.completions.create(\n",
    "        model=model,\n",
    "        messages=[{\"role\": \"user\", \"content\": formatted_prompt}],\n",
    "    )\n",
    "\n",
    "    return response.choices[0].message.content\n",
    "\n",
    "\n",
    "@tracer.tool()\n",
    "def lookup_sales_data(prompt: str) -> str:\n",
    "    \"\"\"Look up sales data by generating a SQL query with the LLM and running it in DuckDB\"\"\"\n",
    "    try:\n",
    "        table_name = \"sales\"\n",
    "        # Copy the store_sales_df DataFrame into a DuckDB table (DuckDB can query pandas DataFrames by variable name)\n",
    "        duckdb.sql(f\"CREATE TABLE IF NOT EXISTS {table_name} AS SELECT * FROM store_sales_df\")\n",
    "\n",
    "        print(store_sales_df.columns)\n",
    "        print(table_name)\n",
    "        # Pass plain column names rather than the pandas Index repr\n",
    "        sql_query = generate_sql_query(prompt, list(store_sales_df.columns), table_name)\n",
    "        # Strip any Markdown code fences around the query first, then trim whitespace\n",
    "        sql_query = sql_query.replace(\"```sql\", \"\").replace(\"```\", \"\").strip()\n",
    "\n",
    "        with tracer.start_as_current_span(\n",
    "            \"execute_sql_query\", openinference_span_kind=\"chain\"\n",
    "        ) as span:\n",
    "            span.set_input(value=sql_query)\n",
    "\n",
    "            # Execute the SQL query\n",
    "            result = duckdb.sql(sql_query).df()\n",
    "            span.set_output(value=str(result))\n",
    "            span.set_status(StatusCode.OK)\n",
    "        return result.to_string()\n",
    "    except Exception as e:\n",
    "        return f\"Error accessing data: {str(e)}\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Index(['Store_Number', 'SKU_Coded', 'Product_Class_Code', 'Sold_Date',\n",
      "       'Qty_Sold', 'Total_Sale_Value', 'On_Promo'],\n",
      "      dtype='object')\n",
      "sales\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "'    Store_Number  SKU_Coded  Product_Class_Code  Sold_Date  Qty_Sold  Total_Sale_Value  On_Promo\\n0           1320    6173050               22875 2021-11-01         1          4.990000         0\\n1           1320    6174250               22875 2021-11-01         1          0.890000         0\\n2           1320    6176200               22975 2021-11-01         2         99.980003         0\\n3           1320    6176800               22800 2021-11-01         1         14.970000         0\\n4           1320    6177250               22975 2021-11-01         1          6.890000         0\\n5           1320    6177300               22800 2021-11-01         1          9.990000         0\\n6           1320    6177350               22800 2021-11-01         2         16.980000         0\\n7           1320    6177700               22875 2021-11-01         1          3.190000         0\\n8           1320    6178000               22875 2021-11-01         2          6.380000         0\\n9           1320    6178250               22800 2021-11-01         1         16.590000         0\\n10          1320    6179250               24400 2021-11-01         1         14.990000         0\\n11          1320    6179300               22800 2021-11-01         2          9.980000         0\\n12          1320    6179400               24400 2021-11-01         2         29.980000         0\\n13          1320    6179450               24400 2021-11-01         1         14.990000         0\\n14          1320    6179500               24400 2021-11-01         1         14.990000         0\\n15          1320    6179750               22800 2021-11-01         2         39.980000         0\\n16          1320    6180550               22975 2021-11-01         1         15.990000         0\\n17          1320    6182050               22975 2021-11-01         1          7.990000         0\\n18          1320    6183750               22850 2021-11-01         3         38.970001         0\\n19          1320    6184100               22975 2021-11-01         3         59.970001         0\\n20          1320    6188550               22950 2021-11-01         2         15.980000         0\\n21          1320    6190050               24425 2021-11-01         5         19.950001         0\\n22          1320    6190150               24425 2021-11-01         1          8.990000         0\\n23          1320    6190200               24425 2021-11-01         1          8.990000         0\\n24          1320    6190250               24425 2021-11-01         1          7.990000         0\\n25          1320    6190350               22950 2021-11-01         1          6.990000         0\\n26          1320    6190400               22950 2021-11-01         1          6.990000         0\\n27          1320    6193750               22875 2021-11-01         1          6.990000         0\\n28          1320    6195350               24375 2021-11-01         1         16.990000         0\\n29          1320    6195800               22850 2021-11-01         3         25.719999         1'"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "example_data = lookup_sales_data(\"Show me all the sales for store 1320 on November 1st, 2021\")\n",
    "example_data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Tool 2: Data Visualization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "class VisualizationConfig(BaseModel):\n",
    "    chart_type: str = Field(..., description=\"Type of chart to generate\")\n",
    "    x_axis: str = Field(..., description=\"Name of the x-axis column\")\n",
    "    y_axis: str = Field(..., description=\"Name of the y-axis column\")\n",
    "    title: str = Field(..., description=\"Title of the chart\")\n",
    "\n",
    "\n",
    "@tracer.chain()\n",
    "def extract_chart_config(data: str, visualization_goal: str) -> dict:\n",
    "    \"\"\"Generate chart visualization configuration\n",
    "\n",
    "    Args:\n",
    "        data: String containing the data to visualize\n",
    "        visualization_goal: Description of what the visualization should show\n",
    "\n",
    "    Returns:\n",
    "        Dictionary containing the chart configuration and the original data\n",
    "    \"\"\"\n",
    "    prompt = f\"\"\"Generate a chart configuration based on this data: {data}\n",
    "    The goal is to show: {visualization_goal}\"\"\"\n",
    "\n",
    "    response = client.beta.chat.completions.parse(\n",
    "        model=model,\n",
    "        messages=[{\"role\": \"user\", \"content\": prompt}],\n",
    "        response_format=VisualizationConfig,\n",
    "    )\n",
    "\n",
    "    try:\n",
    "        # .parsed holds the VisualizationConfig instance; .content is only the raw JSON string\n",
    "        content = response.choices[0].message.parsed\n",
    "\n",
    "        # Return structured chart config\n",
    "        return {\n",
    "            \"chart_type\": content.chart_type,\n",
    "            \"x_axis\": content.x_axis,\n",
    "            \"y_axis\": content.y_axis,\n",
    "            \"title\": content.title,\n",
    "            \"data\": data,\n",
    "        }\n",
    "    except Exception:\n",
    "        return {\n",
    "            \"chart_type\": \"line\",\n",
    "            \"x_axis\": \"date\",\n",
    "            \"y_axis\": \"value\",\n",
    "            \"title\": visualization_goal,\n",
    "            \"data\": data,\n",
    "        }\n",
    "\n",
    "\n",
    "@tracer.chain()\n",
    "def create_chart(config: dict) -> str:\n",
    "    \"\"\"Create a chart based on the configuration\"\"\"\n",
    "    prompt = f\"\"\"Write python code to create a chart based on the following configuration.\n",
    "    Only return the code, no other text.\n",
    "    config: {config}\"\"\"\n",
    "\n",
    "    response = client.chat.completions.create(\n",
    "        model=model,\n",
    "        messages=[{\"role\": \"user\", \"content\": prompt}],\n",
    "    )\n",
    "\n",
    "    code = response.choices[0].message.content\n",
    "    code = code.replace(\"```python\", \"\").replace(\"```\", \"\")\n",
    "    code = code.strip()\n",
    "\n",
    "    return code\n",
    "\n",
    "\n",
    "@tracer.tool()\n",
    "def generate_visualization(data: str, visualization_goal: str) -> str:\n",
    "    \"\"\"Generate a visualization based on the data and goal\"\"\"\n",
    "    config = extract_chart_config(data, visualization_goal)\n",
    "    code = create_chart(config)\n",
    "    return code"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "# code = generate_visualization(example_data, \"A line chart of sales over each day in november.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "@tracer.tool()\n",
    "def run_python_code(code: str) -> str:\n",
    "    \"\"\"Execute Python code with a reduced set of builtins (not a security sandbox: __import__ is exposed)\"\"\"\n",
    "    # Globals for the executed code: limited builtins plus common plotting and data libraries\n",
    "    restricted_globals = {\n",
    "        \"__builtins__\": {\n",
    "            \"print\": print,\n",
    "            \"len\": len,\n",
    "            \"range\": range,\n",
    "            \"sum\": sum,\n",
    "            \"min\": min,\n",
    "            \"max\": max,\n",
    "            \"int\": int,\n",
    "            \"float\": float,\n",
    "            \"str\": str,\n",
    "            \"list\": list,\n",
    "            \"dict\": dict,\n",
    "            \"tuple\": tuple,\n",
    "            \"set\": set,\n",
    "            \"round\": round,\n",
    "            \"__import__\": __import__,\n",
    "            \"json\": __import__(\"json\"),\n",
    "        },\n",
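    "        # fromlist makes __import__ return matplotlib.pyplot itself rather than the top-level matplotlib package\n",
    "        \"plt\": __import__(\"matplotlib.pyplot\", fromlist=[\"pyplot\"]),\n",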
    "        \"pd\": __import__(\"pandas\"),\n",
    "        \"np\": __import__(\"numpy\"),\n",
    "        \"sns\": __import__(\"seaborn\"),\n",
    "    }\n",
    "\n",
    "    try:\n",
    "        # Execute code in restricted environment\n",
    "        exec_locals = {}\n",
    "        exec(code, restricted_globals, exec_locals)\n",
    "\n",
    "        # Tool results are sent back to the model as message content, so always return a string\n",
    "        if \"plt\" in exec_locals:\n",
    "            return \"Code executed successfully and generated a plot\"\n",
    "\n",
    "        return \"Code executed successfully\"\n",
    "\n",
    "    except Exception as e:\n",
    "        return f\"Error executing code: {str(e)}\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Tool 3: Data Analysis"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "@tracer.tool()\n",
    "def analyze_sales_data(prompt: str, data: str) -> str:\n",
    "    \"\"\"Answer a question about the retrieved sales data using the LLM\"\"\"\n",
    "    # Combine the retrieved data and the user's question into a single prompt\n",
    "    analysis_prompt = f\"\"\"Analyze the following data: {data}\n",
    "    Your job is to answer the following question: {prompt}\"\"\"\n",
    "\n",
    "    response = client.chat.completions.create(\n",
    "        model=model,\n",
    "        messages=[{\"role\": \"user\", \"content\": analysis_prompt}],\n",
    "    )\n",
    "\n",
    "    analysis = response.choices[0].message.content\n",
    "    return analysis if analysis else \"No analysis could be generated\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "# analysis = analyze_sales_data(\"What is the most popular product SKU?\", example_data)\n",
    "# analysis"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Tool Schema\n",
    "\n",
    "The router model needs a description of each tool in OpenAI's function-calling format. The cell below defines those tool schemas and maps each tool name to the Python function that implements it:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define tools/functions that can be called by the model\n",
    "tools = [\n",
    "    {\n",
    "        \"type\": \"function\",\n",
    "        \"function\": {\n",
    "            \"name\": \"lookup_sales_data\",\n",
    "            \"description\": \"Look up data from Store Sales Price Elasticity Promotions dataset\",\n",
    "            \"parameters\": {\n",
    "                \"type\": \"object\",\n",
    "                \"properties\": {\n",
    "                    \"prompt\": {\n",
    "                        \"type\": \"string\",\n",
    "                        \"description\": \"The unchanged prompt that the user provided.\",\n",
    "                    }\n",
    "                },\n",
    "                \"required\": [\"prompt\"],\n",
    "            },\n",
    "        },\n",
    "    },\n",
    "    {\n",
    "        \"type\": \"function\",\n",
    "        \"function\": {\n",
    "            \"name\": \"analyze_sales_data\",\n",
    "            \"description\": \"Analyze sales data to extract insights\",\n",
    "            \"parameters\": {\n",
    "                \"type\": \"object\",\n",
    "                \"properties\": {\n",
    "                    \"data\": {\n",
    "                        \"type\": \"string\",\n",
    "                        \"description\": \"The lookup_sales_data tool's output.\",\n",
    "                    },\n",
    "                    \"prompt\": {\n",
    "                        \"type\": \"string\",\n",
    "                        \"description\": \"The unchanged prompt that the user provided.\",\n",
    "                    },\n",
    "                },\n",
    "                \"required\": [\"data\", \"prompt\"],\n",
    "            },\n",
    "        },\n",
    "    },\n",
    "    {\n",
    "        \"type\": \"function\",\n",
    "        \"function\": {\n",
    "            \"name\": \"generate_visualization\",\n",
    "            \"description\": \"Generate Python code to create data visualizations\",\n",
    "            \"parameters\": {\n",
    "                \"type\": \"object\",\n",
    "                \"properties\": {\n",
    "                    \"data\": {\n",
    "                        \"type\": \"string\",\n",
    "                        \"description\": \"The lookup_sales_data tool's output.\",\n",
    "                    },\n",
    "                    \"visualization_goal\": {\n",
    "                        \"type\": \"string\",\n",
    "                        \"description\": \"The goal of the visualization.\",\n",
    "                    },\n",
    "                },\n",
    "                \"required\": [\"data\", \"visualization_goal\"],\n",
    "            },\n",
    "        },\n",
    "    },\n",
    "    # {\n",
    "    #     \"type\": \"function\",\n",
    "    #     \"function\": {\n",
    "    #         \"name\": \"run_python_code\",\n",
    "    #         \"description\": \"Run Python code in a restricted environment\",\n",
    "    #         \"parameters\": {\n",
    "    #             \"type\": \"object\",\n",
    "    #             \"properties\": {\n",
    "    #                 \"code\": {\"type\": \"string\", \"description\": \"The Python code to run.\"}\n",
    "    #             },\n",
    "    #             \"required\": [\"code\"]\n",
    "    #         }\n",
    "    #     }\n",
    "    # }\n",
    "]\n",
    "\n",
    "# Dictionary mapping function names to their implementations\n",
    "tool_implementations = {\n",
    "    \"lookup_sales_data\": lookup_sales_data,\n",
    "    \"analyze_sales_data\": analyze_sales_data,\n",
    "    \"generate_visualization\": generate_visualization,\n",
    "    # \"run_python_code\": run_python_code\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Agent logic\n",
    "\n",
    "With the tools defined, you're ready to define the main routing and tool call handling steps of your agent."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "@tracer.chain()\n",
    "def handle_tool_calls(tool_calls, messages):\n",
    "    for tool_call in tool_calls:\n",
    "        function = tool_implementations[tool_call.function.name]\n",
    "        function_args = json.loads(tool_call.function.arguments)\n",
    "        result = function(**function_args)\n",
    "\n",
    "        messages.append({\"role\": \"tool\", \"content\": result, \"tool_call_id\": tool_call.id})\n",
    "    return messages"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "def start_main_span(messages):\n",
    "    print(\"Starting main span with messages:\", messages)\n",
    "\n",
    "    with tracer.start_as_current_span(\"AgentRun\", openinference_span_kind=\"agent\") as span:\n",
    "        span.set_input(value=messages)\n",
    "        ret = run_agent(messages)\n",
    "        print(\"Main span completed with return value:\", ret)\n",
    "        span.set_output(value=ret)\n",
    "        span.set_status(StatusCode.OK)\n",
    "        return ret\n",
    "\n",
    "\n",
    "def run_agent(messages):\n",
    "    print(\"Running agent with messages:\", messages)\n",
    "    if isinstance(messages, str):\n",
    "        messages = [{\"role\": \"user\", \"content\": messages}]\n",
    "        print(\"Converted string message to list format\")\n",
    "\n",
    "    # Check and add system prompt if needed\n",
    "    if not any(\n",
    "        isinstance(message, dict) and message.get(\"role\") == \"system\" for message in messages\n",
    "    ):\n",
    "        system_prompt = {\n",
    "            \"role\": \"system\",\n",
    "            \"content\": \"You are a helpful assistant that can answer questions about the Store Sales Price Elasticity Promotions dataset.\",\n",
    "        }\n",
    "        # Put the system prompt first so it precedes the user's messages\n",
    "        messages.insert(0, system_prompt)\n",
    "        print(\"Added system prompt to messages\")\n",
    "\n",
    "    while True:\n",
    "        # Router call span\n",
    "        print(\"Starting router call span\")\n",
    "        with tracer.start_as_current_span(\n",
    "            \"router_call\",\n",
    "            openinference_span_kind=\"chain\",\n",
    "        ) as span:\n",
    "            span.set_input(value=messages)\n",
    "\n",
    "            response = client.chat.completions.create(\n",
    "                model=model,\n",
    "                messages=messages,\n",
    "                tools=tools,\n",
    "            )\n",
    "\n",
    "            messages.append(response.choices[0].message.model_dump())\n",
    "            tool_calls = response.choices[0].message.tool_calls\n",
    "            print(\"Received response with tool calls:\", bool(tool_calls))\n",
    "            span.set_status(StatusCode.OK)\n",
    "\n",
    "            if tool_calls:\n",
    "                # Tool calls span\n",
    "                print(\"Processing tool calls\")\n",
    "                messages = handle_tool_calls(tool_calls, messages)\n",
    "                span.set_output(value=tool_calls)\n",
    "            else:\n",
    "                print(\"No tool calls, returning final response\")\n",
    "                span.set_output(value=response.choices[0].message.content)\n",
    "\n",
    "                return response.choices[0].message.content"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Run the agent\n",
    "\n",
    "Your agent is now good to go! Let's try it out with some example questions:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Starting main span with messages: [{'role': 'user', 'content': 'Create a dot plot chart showing sales in 2021'}]\n",
      "Running agent with messages: [{'role': 'user', 'content': 'Create a dot plot chart showing sales in 2021'}]\n",
      "Added system prompt to messages\n",
      "Starting router call span\n",
      "Received response with tool calls: True\n",
      "Processing tool calls\n",
      "Index(['Store_Number', 'SKU_Coded', 'Product_Class_Code', 'Sold_Date',\n",
      "       'Qty_Sold', 'Total_Sale_Value', 'On_Promo'],\n",
      "      dtype='object')\n",
      "sales\n",
      "Starting router call span\n",
      "Received response with tool calls: True\n",
      "Processing tool calls\n",
      "Starting router call span\n",
      "Received response with tool calls: False\n",
      "No tool calls, returning final response\n",
      "Main span completed with return value: I have created a dot plot chart showing sales data for the year 2021. The chart displays the total quantity sold over the specified dates. Here is the visualization code you can run in your environment:\n",
      "\n",
      "```python\n",
      "import pandas as pd\n",
      "import matplotlib.pyplot as plt\n",
      "import json\n",
      "\n",
      "data = \"YOUR_DATA_HERE\"\n",
      "\n",
      "# Replace with actual data\n",
      "data = data.replace(\"'\", '\"')\n",
      "df = pd.DataFrame(json.loads(data))\n",
      "\n",
      "df['Sold_Date'] = pd.to_datetime(df['Sold_Date'])\n",
      "df = df.sort_values('Sold_Date')\n",
      "\n",
      "plt.figure(figsize=(14, 7))\n",
      "plt.plot(df['Sold_Date'], df['Total_Qty_Sold'], marker='o')\n",
      "plt.title('Create a dot plot chart showing sales in 2021')\n",
      "plt.xlabel('Date')\n",
      "plt.ylabel('Total Quantity Sold')\n",
      "plt.xticks(rotation=45)\n",
      "plt.grid()\n",
      "plt.tight_layout()\n",
      "plt.show()\n",
      "```\n",
      "\n",
      "Make sure to replace `\"YOUR_DATA_HERE\"` with the actual data string I provided earlier. This will enable you to visualize the sales data effectively!\n",
      "<IPython.core.display.Markdown object>\n"
     ]
    }
   ],
   "source": [
    "ret = start_main_span(\n",
    "    [{\"role\": \"user\", \"content\": \"Create a dot plot chart showing sales in 2021\"}]\n",
    ")\n",
    "display(Markdown(ret))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "agent_questions = [\n",
    "    \"What was the most popular product SKU?\",\n",
    "    \"What was the total revenue across all stores?\",\n",
    "    \"Which store had the highest sales volume?\",\n",
    "    \"Create a bar chart showing total sales by store\",\n",
    "    \"What percentage of items were sold on promotion?\",\n",
    "    \"Plot daily sales volume over time\",\n",
    "    \"What was the average transaction value?\",\n",
    "    \"Create a box plot of transaction values\",\n",
    "    \"Which products were frequently purchased together?\",\n",
    "    \"Plot a line graph showing the sales trend over time with a 7-day moving average\",\n",
    "]\n",
    "\n",
    "for question in tqdm(agent_questions, desc=\"Processing questions\"):\n",
    "    try:\n",
    "        ret = start_main_span([{\"role\": \"user\", \"content\": question}])\n",
    "    except Exception as e:\n",
    "        print(f\"Error processing question: {question}\")\n",
    "        print(e)\n",
    "        continue"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![Agent Traces](https://storage.googleapis.com/arize-phoenix-assets/assets/images/agent-traces.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Evaluating the agent\n",
    "\n",
    "So your agent looks like it's working, but how can you measure its performance?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [],
   "source": [
    "OpenAIInstrumentor().uninstrument()  # Stop tracing OpenAI calls so the LLM-as-a-Judge evaluation calls below are not captured in the same project."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import nest_asyncio\n",
    "\n",
    "import phoenix as px\n",
    "from phoenix.evals import LLM, create_classifier, evaluate_dataframe\n",
    "from phoenix.experiments import evaluate_experiment, run_experiment\n",
    "from phoenix.experiments.evaluators import create_evaluator\n",
    "from phoenix.experiments.types import Example\n",
    "from phoenix.trace.dsl import SpanQuery\n",
    "\n",
    "nest_asyncio.apply()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "px_client = px.Client()\n",
    "eval_model = LLM(provider=\"openai\", model=\"gpt-4o-mini\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [],
   "source": [
    "TOOL_CALLING_PROMPT_TEMPLATE = \"\"\"\n",
    "You are an evaluation assistant evaluating questions and tool calls to\n",
    "determine whether the tool called would answer the question. The tool\n",
    "calls have been generated by a separate agent, and chosen from the list of\n",
    "tools provided below. It is your job to decide whether that agent chose\n",
    "the right tool to call.\n",
    "\n",
    "    [BEGIN DATA]\n",
    "    ************\n",
    "    [Question]: {question}\n",
    "    ************\n",
    "    [Tool Called]: {tool_call}\n",
    "    [END DATA]\n",
    "\n",
    "Your response must be a single word, either \"correct\" or \"incorrect\",\n",
    "and should not contain any text or characters aside from that word.\n",
    "\"incorrect\" means that the chosen tool would not answer the question,\n",
    "the tool includes information that is not present in the question,\n",
    "or that the tool signature includes parameter values that don't match\n",
    "the formats specified in the tool signatures below.\n",
    "\n",
    "\"correct\" means the correct tool call was chosen, the correct parameters\n",
    "were extracted from the question, the tool call generated is runnable and correct,\n",
    "and that no information absent from the question was used\n",
    "in the generated tool call.\n",
    "\n",
    "    [Tool Definitions]: \"generate_visualization, lookup_sales_data, analyze_sales_data, run_python_code\"\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Function Calling Evals using LLM as a Judge\n",
    "\n",
    "This first evaluation uses another LLM to judge your agent's router choices.\n",
    "\n",
    "It follows a standard pattern:\n",
    "1. Export traces from Phoenix\n",
    "2. Prepare those exported traces in a dataframe with the correct columns\n",
    "3. Use `create_classifier` and `evaluate_dataframe` to run a standard template across each row of that dataframe and produce an eval label\n",
    "4. Upload the results back into Phoenix"
   ]
  },
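  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Conceptually, every row goes through the same three steps: fill the template with that row's columns, send the prompt to the judge model, and map the reply onto one of the allowed labels. The sketch below mimics that loop with a stub, `fake_judge` (hypothetical, standing in for a real model call), so it runs without an API key. It illustrates the pattern only; it is not how Phoenix implements `evaluate_dataframe`.\n",
    "\n",
    "```python\n",
    "import pandas as pd\n",
    "\n",
    "# Hypothetical, simplified template; the notebook's real one is TOOL_CALLING_PROMPT_TEMPLATE.\n",
    "TEMPLATE = 'Question: {question} | Tool called: {tool_call} | Reply with correct or incorrect.'\n",
    "LABELS = ['correct', 'incorrect']\n",
    "\n",
    "\n",
    "def fake_judge(prompt: str) -> str:\n",
    "    # Stand-in for an LLM call: marks a row incorrect when no tool was used.\n",
    "    return 'incorrect' if 'No tool used' in prompt else ' Correct. '\n",
    "\n",
    "\n",
    "def snap_to_label(reply: str) -> str:\n",
    "    # Normalize the raw reply; anything outside LABELS becomes 'unparseable'.\n",
    "    cleaned = reply.strip().strip('.').lower()\n",
    "    return cleaned if cleaned in LABELS else 'unparseable'\n",
    "\n",
    "\n",
    "def judge_dataframe(df: pd.DataFrame) -> pd.DataFrame:\n",
    "    out = df.copy()\n",
    "    rows = df[['question', 'tool_call']].to_dict('records')\n",
    "    out['label'] = [snap_to_label(fake_judge(TEMPLATE.format(**row))) for row in rows]\n",
    "    out['score'] = (out['label'] == 'correct').astype(int)\n",
    "    return out\n",
    "\n",
    "\n",
    "sample = pd.DataFrame({\n",
    "    'question': ['Which store had the highest sales volume?', 'Plot daily sales volume over time'],\n",
    "    'tool_call': ['lookup_sales_data', 'No tool used'],\n",
    "})\n",
    "print(judge_dataframe(sample)[['tool_call', 'label', 'score']])\n",
    "```\n",
    "\n",
    "Mapping the reply onto a fixed label set is what turns the judge's free text into a metric. In this sketch, a reply that matches neither label is surfaced as its own category instead of being silently counted as correct or incorrect."
   ]
  },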
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "query = (\n",
    "    SpanQuery()\n",
    "    .where(\n",
    "        \"span_kind == 'LLM'\",\n",
    "    )\n",
    "    .select(question=\"input.value\", output_messages=\"llm.output_messages\")\n",
    ")\n",
    "\n",
    "# The Phoenix Client can take this query and return the dataframe.\n",
    "tool_calls_df = px.Client().query_spans(query, project_name=project_name, timeout=None)\n",
    "tool_calls_df.dropna(subset=[\"output_messages\"], inplace=True)\n",
    "\n",
    "\n",
    "def get_tool_call(outputs):\n",
    "    if outputs[0].get(\"message\").get(\"tool_calls\"):\n",
    "        return (\n",
    "            outputs[0]\n",
    "            .get(\"message\")\n",
    "            .get(\"tool_calls\")[0]\n",
    "            .get(\"tool_call\")\n",
    "            .get(\"function\")\n",
    "            .get(\"name\")\n",
    "        )\n",
    "    else:\n",
    "        return \"No tool used\"\n",
    "\n",
    "\n",
    "tool_calls_df[\"tool_call\"] = tool_calls_df[\"output_messages\"].apply(get_tool_call)\n",
    "tool_calls_df.head()"
   ]
  },
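  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`get_tool_call` chains `.get()` calls, so it assumes each output message is nested as `message`, then `tool_calls`, then `tool_call`, then `function`, then `name`; an entry missing the `message` key would raise an `AttributeError`. If you hit that, a more defensive variant along the lines below is one option. `safe_tool_call_name` is a hypothetical helper, and the sample rows are hand-written to the shape `get_tool_call` expects rather than copied from a real export.\n",
    "\n",
    "```python\n",
    "def safe_tool_call_name(outputs) -> str:\n",
    "    # Walk the nested structure, falling back to 'No tool used' at any missing level.\n",
    "    if not outputs:\n",
    "        return 'No tool used'\n",
    "    message = outputs[0].get('message') or {}\n",
    "    tool_calls = message.get('tool_calls') or []\n",
    "    if not tool_calls:\n",
    "        return 'No tool used'\n",
    "    function = (tool_calls[0].get('tool_call') or {}).get('function') or {}\n",
    "    return function.get('name', 'No tool used')\n",
    "\n",
    "\n",
    "# Hand-written sample rows (assumed shape, for illustration only).\n",
    "with_tool = [{'message': {'role': 'assistant', 'tool_calls': [\n",
    "    {'tool_call': {'function': {'name': 'lookup_sales_data', 'arguments': '{}'}}}]}}]\n",
    "without_tool = [{'message': {'role': 'assistant', 'content': 'Done'}}]\n",
    "\n",
    "print(safe_tool_call_name(with_tool))     # lookup_sales_data\n",
    "print(safe_tool_call_name(without_tool))  # No tool used\n",
    "print(safe_tool_call_name([{}]))          # No tool used\n",
    "```\n",
    "\n",
    "You would apply it the same way as the original helper: `tool_calls_df['output_messages'].apply(safe_tool_call_name)`."
   ]
  },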
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tool_call_evaluator = create_classifier(\n",
    "    name=\"tool_call_eval\",\n",
    "    llm=eval_model,\n",
    "    prompt_template=TOOL_CALLING_PROMPT_TEMPLATE,\n",
    "    choices=[\"correct\", \"incorrect\"],\n",
    ")\n",
    "\n",
    "tool_call_eval = evaluate_dataframe(dataframe=tool_calls_df, evaluators=[tool_call_evaluator])\n",
    "\n",
    "tool_call_eval[\"score\"] = tool_call_eval.apply(\n",
    "    lambda x: 1 if x[\"label\"] == \"correct\" else 0, axis=1\n",
    ")\n",
    "\n",
    "tool_call_eval.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from phoenix.client import AsyncClient\n",
    "\n",
    "async_px_client = AsyncClient()  # new name, so px_client stays the legacy client used for upload_dataset below\n",
    "await async_px_client.spans.log_span_annotations_dataframe(\n",
    "    dataframe=tool_call_eval, annotation_name=\"Tool Calling Eval\", annotator_kind=\"LLM\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You should now see eval labels in Phoenix.\n",
    "\n",
    "![Function Calling Evals](https://storage.googleapis.com/arize-phoenix-assets/assets/images/function-calling-evals.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Function Calling Evals using Ground Truth\n",
    "\n",
    "The above example works, but if you have ground-truth labeled data, you can use it to get a more accurate measure of your router's performance by running an experiment.\n",
    "\n",
    "Experiments also follow a standard step-by-step process in Phoenix:\n",
    "1. Create a dataset of test cases and, optionally, expected outputs\n",
    "2. Create a task to run on each test case; usually this invokes your agent or a specific step of it\n",
    "3. Create evaluator(s) to run on each output of your task\n",
    "4. Visualize results in Phoenix"
   ]
  },
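  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Stripped of tracing, concurrency, and the UI, those four steps boil down to a loop like the one below. It is a plain-Python sketch for intuition only: `run_mini_experiment`, `stub_router`, and the dataset values are hypothetical, and this is not how `run_experiment` is implemented.\n",
    "\n",
    "```python\n",
    "from typing import Callable\n",
    "\n",
    "\n",
    "def run_mini_experiment(examples: list, task: Callable, evaluators: list) -> list:\n",
    "    results = []\n",
    "    for example in examples:\n",
    "        # Step 2: run the task on one test case.\n",
    "        output = task(example['input'])\n",
    "        # Step 3: score that output with every evaluator.\n",
    "        scores = {ev.__name__: ev(output=output, expected=example['expected']) for ev in evaluators}\n",
    "        results.append({'input': example['input'], 'output': output, 'scores': scores})\n",
    "    return results\n",
    "\n",
    "\n",
    "# Step 1: a tiny dataset with expected outputs (made-up values).\n",
    "examples = [\n",
    "    {'input': {'question': 'Which store had the highest sales volume?'},\n",
    "     'expected': {'tool_calls': 'lookup_sales_data'}},\n",
    "    {'input': {'question': 'Plot daily sales volume over time'},\n",
    "     'expected': {'tool_calls': 'generate_visualization'}},\n",
    "]\n",
    "\n",
    "\n",
    "def stub_router(example_input: dict) -> list:\n",
    "    # Hypothetical stand-in for the router: always picks the lookup tool.\n",
    "    return ['lookup_sales_data']\n",
    "\n",
    "\n",
    "def exact_match(output: list, expected: dict) -> bool:\n",
    "    return output == expected['tool_calls'].split(', ')\n",
    "\n",
    "\n",
    "# Step 4 happens in the Phoenix UI; here the results are just printed.\n",
    "for row in run_mini_experiment(examples, stub_router, [exact_match]):\n",
    "    print(row['input']['question'], row['scores'])\n",
    "```"
   ]
  },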
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import uuid\n",
    "\n",
    "id = str(uuid.uuid4())\n",
    "\n",
    "agent_tool_responses = {\n",
    "    \"What was the most popular product SKU?\": \"lookup_sales_data, analyze_sales_data\",\n",
    "    \"What was the total revenue across all stores?\": \"lookup_sales_data, analyze_sales_data\",\n",
    "    \"Which store had the highest sales volume?\": \"lookup_sales_data, analyze_sales_data\",\n",
    "    \"Create a bar chart showing total sales by store\": \"generate_visualization, lookup_sales_data, run_python_code\",\n",
    "    \"What percentage of items were sold on promotion?\": \"lookup_sales_data, analyze_sales_data\",\n",
    "    \"Plot daily sales volume over time\": \"generate_visualization, lookup_sales_data, run_python_code\",\n",
    "    \"What was the average transaction value?\": \"lookup_sales_data, analyze_sales_data\",\n",
    "    \"Create a box plot of transaction values\": \"generate_visualization, lookup_sales_data, run_python_code\",\n",
    "    \"Which products were frequently purchased together?\": \"lookup_sales_data, analyze_sales_data\",\n",
    "    \"Plot a line graph showing the sales trend over time with a 7-day moving average\": \"generate_visualization, lookup_sales_data, run_python_code\",\n",
    "}\n",
    "\n",
    "tool_calling_df = pd.DataFrame(agent_tool_responses.items(), columns=[\"question\", \"tool_calls\"])\n",
    "dataset = px_client.upload_dataset(\n",
    "    dataframe=tool_calling_df,\n",
    "    dataset_name=f\"tool_calling_ground_truth_{id}\",\n",
    "    input_keys=[\"question\"],\n",
    "    output_keys=[\"tool_calls\"],\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For your task, you can run only the router call of your agent:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [],
   "source": [
    "def run_router_step(example: Example) -> list[str]:\n",
    "    messages = [\n",
    "        {\n",
    "            \"role\": \"system\",\n",
    "            \"content\": \"You are a helpful assistant that can answer questions about the Store Sales Price Elasticity Promotions dataset.\",\n",
    "        }\n",
    "    ]\n",
    "    messages.append({\"role\": \"user\", \"content\": example.input.get(\"question\")})\n",
    "\n",
    "    response = client.chat.completions.create(\n",
    "        model=model,\n",
    "        messages=messages,\n",
    "        tools=tools,\n",
    "    )\n",
    "    tool_calls = []\n",
    "    for tool_call in response.choices[0].message.tool_calls or []:  # None when the model answers without tools\n",
    "        tool_calls.append(tool_call.function.name)\n",
    "    return tool_calls"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because you have expected outputs, your evaluator can be a simple code check that compares them with the router's tool calls. Without expected outputs, you could use an LLM as a Judge here instead."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [],
   "source": [
    "def tools_match(expected: dict, output: list) -> bool:\n",
    "    # expected is the example's output dict; output is the list returned by run_router_step\n",
    "    expected_tools = expected.get(\"tool_calls\").split(\", \")\n",
    "    return expected_tools == output"
   ]
  },
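  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`tools_match` compares two lists, so the order in which the model returned its tool calls also has to match the ground truth. If you decide order should not count, an order-insensitive comparison is one option; this is a judgment call, and `tools_match_unordered` is a hypothetical name. Either way, `run_router_step` only sees the router's first response, so an expected path such as `generate_visualization, lookup_sales_data, run_python_code` only matches if all three calls arrive in that first response.\n",
    "\n",
    "```python\n",
    "from collections import Counter\n",
    "\n",
    "\n",
    "def tools_match_unordered(expected: dict, output: list) -> bool:\n",
    "    # Compare tool names as multisets: order is ignored, but duplicates still count.\n",
    "    names = expected.get('tool_calls', '').split(',')\n",
    "    expected_tools = [name.strip() for name in names if name.strip()]\n",
    "    return Counter(expected_tools) == Counter(output or [])\n",
    "\n",
    "\n",
    "expected = {'tool_calls': 'lookup_sales_data, analyze_sales_data'}\n",
    "print(tools_match_unordered(expected, ['analyze_sales_data', 'lookup_sales_data']))  # True\n",
    "print(tools_match_unordered(expected, ['lookup_sales_data']))  # False\n",
    "```"
   ]
  },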
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "experiment = run_experiment(\n",
    "    dataset,\n",
    "    run_router_step,\n",
    "    evaluators=[tools_match],\n",
    "    experiment_name=\"Tool Calling Eval\",\n",
    "    experiment_description=\"Evaluating the tool calling step of the agent\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Tool Evals\n",
    "\n",
    "The next piece of your agent to evaluate is its tools. Each tool is usually evaluated differently, and some examples follow. For more ideas, [Phoenix's built-in evaluators](https://arize.com/docs/phoenix/evaluation/how-to-evals/running-pre-tested-evals) suggest other metrics you can use."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Evaluating our SQL generation tool"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Replace this with a human annotated set of ground truth data, instead of generated examples\n",
    "\n",
    "db_lookup_questions = [\n",
    "    \"What was the most popular product SKU?\",\n",
    "    \"Which store had the highest total sales value?\",\n",
    "    \"How many items were sold on promotion?\",\n",
    "    \"What was the average quantity sold per transaction?\",\n",
    "    \"Which product class code generated the most revenue?\",\n",
    "    \"What day of the week had the highest sales volume?\",\n",
    "    \"How many unique stores made sales?\",\n",
    "    \"What was the highest single transaction value?\",\n",
    "    \"Which products were frequently sold together?\",\n",
    "    \"What's the trend in sales over time?\",\n",
    "]\n",
    "\n",
    "expected_results = []\n",
    "\n",
    "for question in tqdm(db_lookup_questions[:], desc=\"Processing SQL lookup questions\"):  # copy: failed questions are removed below\n",
    "    try:\n",
    "        with suppress_tracing():\n",
    "            expected_results.append(lookup_sales_data(question))\n",
    "    except Exception as e:\n",
    "        print(f\"Error processing question: {question}\")\n",
    "        print(e)\n",
    "        db_lookup_questions.remove(question)\n",
    "\n",
    "# Create a DataFrame with the questions\n",
    "questions_df = pd.DataFrame({\"question\": db_lookup_questions, \"expected_result\": expected_results})\n",
    "\n",
    "questions_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = px_client.upload_dataset(\n",
    "    dataframe=questions_df,\n",
    "    dataset_name=f\"sales_db_lookup_questions_{id}\",\n",
    "    input_keys=[\"question\"],\n",
    "    output_keys=[\"expected_result\"],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [],
   "source": [
    "def run_sql_query(example: Example) -> str:\n",
    "    with suppress_tracing():\n",
    "        return lookup_sales_data(example.input.get(\"question\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate_sql_result(output: str, expected: dict) -> bool:\n",
    "    # Extract just the numbers from both strings\n",
    "    result_nums = \"\".join(filter(str.isdigit, output))\n",
    "    expected_nums = \"\".join(filter(str.isdigit, expected.get(\"expected_result\")))\n",
    "    return result_nums == expected_nums"
   ]
  },
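  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Concatenating every digit is a coarse check. `'12.5'` and `'125'` both reduce to `'125'` and pass, while the same total written as `'1,000'` and `'1000.0'` reduces to `'1000'` and `'10000'` and fails. A stricter alternative, sketched below, parses each number and compares the sorted values within a small tolerance. It assumes commas in the results are thousands separators; `extract_numbers` and `numbers_match` are hypothetical helpers.\n",
    "\n",
    "```python\n",
    "import math\n",
    "import re\n",
    "\n",
    "NUMBER = re.compile(r'-?\\d+(?:\\.\\d+)?')\n",
    "\n",
    "\n",
    "def extract_numbers(text: str) -> list:\n",
    "    # Assumes commas are thousands separators, so '1,000' parses as 1000.0.\n",
    "    return sorted(float(m) for m in NUMBER.findall(text.replace(',', '')))\n",
    "\n",
    "\n",
    "def numbers_match(output: str, expected: str, rel_tol: float = 1e-6) -> bool:\n",
    "    got, want = extract_numbers(output or ''), extract_numbers(expected or '')\n",
    "    return len(got) == len(want) and all(\n",
    "        math.isclose(a, b, rel_tol=rel_tol) for a, b in zip(got, want)\n",
    "    )\n",
    "\n",
    "\n",
    "print(numbers_match('Total: 1,000', 'total = 1000.0'))  # True\n",
    "print(numbers_match('12.5', '125'))  # False\n",
    "```\n",
    "\n",
    "To use it in the experiment, you would call it from an evaluator with the same signature as `evaluate_sql_result`, passing `expected.get('expected_result')` as the second argument."
   ]
  },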
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "experiment = run_experiment(\n",
    "    dataset,\n",
    "    run_sql_query,\n",
    "    evaluators=[evaluate_sql_result],\n",
    "    experiment_name=\"SQL Query Eval\",\n",
    "    experiment_description=\"Evaluating the SQL query generation step of the agent\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Evaluating our Python code generation tool"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Replace this with a human annotated set of ground truth data, instead of generated examples\n",
    "\n",
    "code_generation_questions = [\n",
    "    \"Create a bar chart showing total sales by store\",\n",
    "    \"Plot daily sales volume over time\",\n",
    "    \"Plot a line graph showing the sales trend over time with a 7-day moving average\",\n",
    "    \"Create a histogram of quantities sold per transaction\",\n",
    "    \"Generate a pie chart showing sales distribution across product classes\",\n",
    "    \"Create a stacked bar chart showing promotional vs non-promotional sales by store\",\n",
    "    \"Generate a heatmap of sales by day of week and store number\",\n",
    "    \"Plot a line chart comparing sales trends between top 5 stores\",\n",
    "]\n",
    "\n",
    "example_data = []\n",
    "chart_configs = []\n",
    "for question in tqdm(code_generation_questions[:], desc=\"Processing code generation questions\"):\n",
    "    try:\n",
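    "        with suppress_tracing():\n",
    "            data = lookup_sales_data(question)\n",
    "            config = json.dumps(extract_chart_config(data, question))\n",
    "        # Append only after both calls succeed, so the three columns stay the same length\n",
    "        example_data.append(data)\n",
    "        chart_configs.append(config)\n",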
    "    except Exception as e:\n",
    "        print(f\"Error processing question: {question}\")\n",
    "        print(e)\n",
    "        code_generation_questions.remove(question)\n",
    "\n",
    "code_generation_df = pd.DataFrame(\n",
    "    {\n",
    "        \"question\": code_generation_questions,\n",
    "        \"example_data\": example_data,\n",
    "        \"chart_configs\": chart_configs,\n",
    "    }\n",
    ")\n",
    "\n",
    "dataset = px_client.upload_dataset(\n",
    "    dataframe=code_generation_df,\n",
    "    dataset_name=f\"code_generation_questions_{id}\",\n",
    "    input_keys=[\"question\", \"example_data\", \"chart_configs\"],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [],
   "source": [
    "def run_code_generation(example: Example) -> dict:\n",
    "    with suppress_tracing():\n",
    "        chart_config = extract_chart_config(\n",
    "            data=example.input.get(\"example_data\"), visualization_goal=example.input.get(\"question\")\n",
    "        )\n",
    "        code = generate_visualization(\n",
    "            visualization_goal=example.input.get(\"question\"), data=example.input.get(\"example_data\")\n",
    "        )\n",
    "\n",
    "    return {\"code\": code, \"chart_config\": chart_config}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You don't have ground-truth code to compare against here, so use a simple code evaluator instead: try to run the generated code and catch any errors."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [],
   "source": [
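    "def code_is_runnable(output: dict) -> bool:\n",
    "    \"\"\"Check whether the generated code runs without raising an exception.\"\"\"\n",
    "    code = (output or {}).get(\"code\") or \"\"\n",
    "    code = code.replace(\"```python\", \"\").replace(\"```\", \"\").strip()\n",
    "    if not code:\n",
    "        return False\n",
    "    try:\n",
    "        # Runs model-generated code in this process; a fresh globals dict keeps notebook variables intact\n",
    "        exec(code, {})\n",
    "        return True\n",
    "    except Exception:\n",
    "        return False"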
   ]
  },
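  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`exec` runs the model-generated code inside the notebook process, and a blocking `plt.show()` can stall the evaluation. One alternative is to run each snippet in a fresh Python subprocess with a timeout and treat a zero exit status as runnable. The sketch below assumes matplotlib honors the `MPLBACKEND=Agg` environment variable so no windows open; `runs_in_subprocess` is a hypothetical helper. It isolates notebook state, not security, so only run code you would be willing to execute on this machine.\n",
    "\n",
    "```python\n",
    "import os\n",
    "import subprocess\n",
    "import sys\n",
    "\n",
    "\n",
    "def runs_in_subprocess(code: str, timeout: float = 30.0) -> bool:\n",
    "    fence = chr(96) * 3  # three backticks, the markdown code-fence marker\n",
    "    # Strip markdown fences the model may wrap around its answer.\n",
    "    code = (code or '').replace(fence + 'python', '').replace(fence, '').strip()\n",
    "    if not code:\n",
    "        return False\n",
    "    env = {**os.environ, 'MPLBACKEND': 'Agg'}  # non-interactive plotting backend\n",
    "    try:\n",
    "        result = subprocess.run(\n",
    "            [sys.executable, '-c', code], capture_output=True, timeout=timeout, env=env\n",
    "        )\n",
    "    except subprocess.TimeoutExpired:\n",
    "        return False\n",
    "    return result.returncode == 0\n",
    "\n",
    "\n",
    "print(runs_in_subprocess('x = 1 + 1'))  # True\n",
    "print(runs_in_subprocess('x = (1 +'))   # False\n",
    "```"
   ]
  },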
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [],
   "source": [
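    "def evaluate_chart_config(output: dict, input: dict) -> bool:\n",
    "    # The reference config is stored as a JSON string in the dataset's chart_configs input column\n",
    "    expected_config = json.loads(input.get(\"chart_configs\"))\n",
    "    return output.get(\"chart_config\") == expected_config"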
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "experiment = run_experiment(\n",
    "    dataset,\n",
    "    run_code_generation,\n",
    "    evaluators=[code_is_runnable, evaluate_chart_config],\n",
    "    experiment_name=\"Code Generation Eval\",\n",
    "    experiment_description=\"Evaluating the code generation step of the agent\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Evaluating the agent path and convergence\n",
    "\n",
    "The last piece of your agent to evaluate is its path, which shows how efficiently your agent executes. Does it need to call the same tool multiple times? Does it skip steps it shouldn't, and have to backtrack later? Convergence or path evals can tell you this.\n",
    "\n",
    "Convergence evals operate slightly differently. The one you'll use below relies on knowing the minimum number of steps the agent takes for a given type of query. Instead of running an experiment and its evaluators in one go, you'll run the experiment first and then, after it completes, attach an evaluator that calculates convergence.\n",
    "\n",
    "The workflow is as follows:\n",
    "1. Create a dataset of the same type of question, phrased a different way each time. The agent should take the same path for each, but you'll often find it doesn't.\n",
    "2. Create a task that runs the agent on each question while tracking the number of steps it takes.\n",
    "3. Run the experiment without an evaluator.\n",
    "4. Calculate the minimum number of steps taken to complete the task.\n",
    "5. Create an evaluator that compares the number of steps in each run against that minimum.\n",
    "6. Run this evaluator on your experiment from step 3.\n",
    "7. View your results in Phoenix."
   ]
  },
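  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a worked example of steps 4 and 5: if five runs take 5, 5, 7, 9, and 5 messages, the minimum is 5 and is treated as the optimal path. Scoring each run as optimal divided by actual gives 1.0, 1.0, about 0.71, about 0.56, and 1.0, for a mean of about 0.85. The sketch below does the same arithmetic on those made-up path lengths; `convergence_scores` is a hypothetical helper that mirrors the evaluator defined later in this section.\n",
    "\n",
    "```python\n",
    "def convergence_scores(path_lengths: list) -> list:\n",
    "    # Treat the shortest observed path as optimal; score each run as optimal / actual.\n",
    "    valid = [n for n in path_lengths if n]\n",
    "    if not valid:\n",
    "        return [0.0 for _ in path_lengths]\n",
    "    optimal = min(valid)\n",
    "    # Runs with a missing or zero path length score 0.0, as in evaluate_path_length.\n",
    "    return [optimal / n if n else 0.0 for n in path_lengths]\n",
    "\n",
    "\n",
    "scores = convergence_scores([5, 5, 7, 9, 5])\n",
    "print([round(s, 2) for s in scores])  # [1.0, 1.0, 0.71, 0.56, 1.0]\n",
    "print(round(sum(scores) / len(scores), 2))  # 0.85\n",
    "```\n",
    "\n",
    "A score of 1.0 means the run matched the shortest path seen, and lower scores mean extra steps. Because the optimum is the observed minimum, the score measures consistency across phrasings rather than distance from a known best path."
   ]
  },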
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Replace this with a human annotated set of ground truth data, instead of generated examples\n",
    "\n",
    "convergence_questions = [\n",
    "    \"What was the average quantity sold per transaction?\",\n",
    "    \"What is the mean number of items per sale?\",\n",
    "    \"Calculate the typical quantity per transaction\",\n",
    "    \"Show me the average number of units sold in each transaction\",\n",
    "    \"What's the mean transaction size in terms of quantity?\",\n",
    "    \"On average, how many items were purchased per transaction?\",\n",
    "    \"What is the average basket size per sale?\",\n",
    "    \"Calculate the mean number of products per purchase\",\n",
    "    \"What's the typical number of units per order?\",\n",
    "    \"Find the average quantity of items in each transaction\",\n",
    "    \"What is the average number of products bought per purchase?\",\n",
    "    \"Tell me the mean quantity of items in a typical transaction\",\n",
    "    \"How many items does a customer buy on average per transaction?\",\n",
    "    \"What's the usual number of units in each sale?\",\n",
    "    \"Calculate the average basket quantity per order\",\n",
    "    \"What is the typical amount of products per transaction?\",\n",
    "    \"Show the mean number of items customers purchase per visit\",\n",
    "    \"What's the average quantity of units per shopping trip?\",\n",
    "    \"How many products do customers typically buy in one transaction?\",\n",
    "    \"What is the standard basket size in terms of quantity?\",\n",
    "]\n",
    "\n",
    "convergence_df = pd.DataFrame({\"question\": convergence_questions})\n",
    "\n",
    "dataset = px_client.upload_dataset(\n",
    "    dataframe=convergence_df, dataset_name=f\"convergence_questions_{id}\", input_keys=[\"question\"]\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "def format_message_steps(messages):\n",
    "    \"\"\"\n",
    "    Convert a list of message objects into a readable format that shows the steps taken.\n",
    "\n",
    "    Args:\n",
    "        messages (list): A list of message objects containing role, content, tool calls, etc.\n",
    "\n",
    "    Returns:\n",
    "        str: A readable string showing the steps taken.\n",
    "    \"\"\"\n",
    "    steps = []\n",
    "    for message in messages:\n",
    "        role = message.get(\"role\")\n",
    "        if role == \"user\":\n",
    "            steps.append(f\"User: {message.get('content')}\")\n",
    "        elif role == \"system\":\n",
    "            steps.append(\"System: Provided context\")\n",
    "        elif role == \"assistant\":\n",
    "            if message.get(\"tool_calls\"):\n",
    "                for tool_call in message[\"tool_calls\"]:\n",
    "                    tool_name = tool_call[\"function\"][\"name\"]\n",
    "                    steps.append(f\"Assistant: Called tool '{tool_name}'\")\n",
    "            else:\n",
    "                steps.append(f\"Assistant: {message.get('content')}\")\n",
    "        elif role == \"tool\":\n",
    "            steps.append(f\"Tool response: {message.get('content')}\")\n",
    "\n",
    "    return \"\\n\".join(steps)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "def run_agent_and_track_path(example: Example) -> dict:\n",
    "    print(\"Starting main span with messages:\", example.input.get(\"question\"))\n",
    "    messages = [{\"role\": \"user\", \"content\": example.input.get(\"question\")}]\n",
    "    ret = run_agent_messages(messages)\n",
    "    return {\"path_length\": len(ret), \"messages\": format_message_steps(ret)}\n",
    "\n",
    "\n",
    "def run_agent_messages(messages):\n",
    "    print(\"Running agent with messages:\", messages)\n",
    "    if isinstance(messages, str):\n",
    "        messages = [{\"role\": \"user\", \"content\": messages}]\n",
    "        print(\"Converted string message to list format\")\n",
    "\n",
    "    # Check and add system prompt if needed\n",
    "    if not any(\n",
    "        isinstance(message, dict) and message.get(\"role\") == \"system\" for message in messages\n",
    "    ):\n",
    "        system_prompt = {\n",
    "            \"role\": \"system\",\n",
    "            \"content\": \"You are a helpful assistant that can answer questions about the Store Sales Price Elasticity Promotions dataset.\",\n",
    "        }\n",
    "        messages.append(system_prompt)\n",
    "        print(\"Added system prompt to messages\")\n",
    "\n",
    "    while True:\n",
    "        # Router call span\n",
    "        print(\"Starting router\")\n",
    "\n",
    "        response = client.chat.completions.create(\n",
    "            model=model,\n",
    "            messages=messages,\n",
    "            tools=tools,\n",
    "        )\n",
    "\n",
    "        messages.append(response.choices[0].message.model_dump())\n",
    "        tool_calls = response.choices[0].message.tool_calls\n",
    "        print(\"Received response with tool calls:\", bool(tool_calls))\n",
    "\n",
    "        if tool_calls:\n",
    "            # Tool calls span\n",
    "            print(\"Processing tool calls\")\n",
    "            messages = handle_tool_calls(tool_calls, messages)\n",
    "        else:\n",
    "            print(\"No tool calls, returning final response\")\n",
    "            return messages"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "experiment = run_experiment(\n",
    "    dataset,\n",
    "    run_agent_and_track_path,\n",
    "    experiment_name=\"Convergence Eval\",\n",
    "    experiment_description=\"Evaluating the convergence of the agent\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "experiment.as_dataframe()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "outputs = experiment.as_dataframe()[\"output\"].to_dict().values()\n",
    "optimal_path_length = min(\n",
    "    output.get(\"path_length\")\n",
    "    for output in outputs\n",
    "    if output and output.get(\"path_length\") is not None\n",
    ")\n",
    "print(f\"The optimal path length is {optimal_path_length}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [],
   "source": [
    "@create_evaluator(name=\"Convergence Eval\", kind=\"CODE\")\n",
    "def evaluate_path_length(output: dict) -> float:\n",
    "    if output and output.get(\"path_length\"):\n",
    "        return optimal_path_length / float(output.get(\"path_length\"))\n",
    "    else:\n",
    "        return 0.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "experiment = evaluate_experiment(experiment, evaluators=[evaluate_path_length])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Advanced: Combining all the evals into one experiment\n",
    "\n",
    "As an optional final step, you can combine all the evaluators and experiments above into a single experiment. This requires some more advanced data wrangling, but gives you a single report on your agent's performance."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Build a version of our agent that tracks all the necessary information for evals"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 72,
   "metadata": {},
   "outputs": [],
   "source": [
    "def process_messages(messages):\n",
    "    tool_calls = []\n",
    "    tool_responses = []\n",
    "    final_output = None\n",
    "\n",
    "    for message in messages:\n",
    "        # Extract tool calls\n",
    "        if \"tool_calls\" in message and message[\"tool_calls\"]:\n",
    "            for tool_call in message[\"tool_calls\"]:\n",
    "                tool_name = tool_call[\"function\"][\"name\"]\n",
    "                tool_input = tool_call[\"function\"][\"arguments\"]\n",
    "                tool_calls.append(tool_name)\n",
    "\n",
    "                # Prepare tool response structure, keyed by the tool call's ID\n",
    "                tool_responses.append(\n",
    "                    {\n",
    "                        \"tool_call_id\": tool_call[\"id\"],\n",
    "                        \"tool_name\": tool_name,\n",
    "                        \"tool_input\": tool_input,\n",
    "                        \"tool_response\": None,\n",
    "                    }\n",
    "                )\n",
    "\n",
    "        # Attach each tool response to the call with the matching ID\n",
    "        if message[\"role\"] == \"tool\" and \"tool_call_id\" in message:\n",
    "            for tool_response in tool_responses:\n",
    "                if tool_response[\"tool_call_id\"] == message[\"tool_call_id\"]:\n",
    "                    tool_response[\"tool_response\"] = message[\"content\"]\n",
    "\n",
    "        # Extract final output\n",
    "        if (\n",
    "            message[\"role\"] == \"assistant\"\n",
    "            and not message.get(\"tool_calls\")\n",
    "            and not message.get(\"function_call\")\n",
    "        ):\n",
    "            final_output = message[\"content\"]\n",
    "\n",
    "    result = {\n",
    "        \"tool_calls\": tool_calls,\n",
    "        \"tool_responses\": tool_responses,\n",
    "        \"final_output\": final_output,\n",
    "        \"unchanged_messages\": messages,\n",
    "        \"path_length\": len(messages),\n",
    "    }\n",
    "\n",
    "    return result"
   ]
  },
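  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick way to check what `process_messages` extracts is to run it on a small hand-written conversation. The messages below are hypothetical: they only mimic the shape of the OpenAI chat messages the agent produces (an assistant turn with `tool_calls`, a `tool` turn with a matching `tool_call_id`, and a final assistant answer). The call id, tool arguments, and final answer are made up for illustration, and the tool output reuses the expected SQL result for the SKU question used later in this notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical conversation in the OpenAI chat message format (id, arguments and answer are made up)\n",
    "example_messages = [\n",
    "    {\"role\": \"user\", \"content\": \"What was the most popular product SKU?\"},\n",
    "    {\n",
    "        \"role\": \"assistant\",\n",
    "        \"content\": None,\n",
    "        \"tool_calls\": [\n",
    "            {\n",
    "                \"id\": \"call_example_1\",\n",
    "                \"type\": \"function\",\n",
    "                \"function\": {\n",
    "                    \"name\": \"lookup_sales_data\",\n",
    "                    \"arguments\": '{\"prompt\": \"What was the most popular product SKU?\"}',\n",
    "                },\n",
    "            }\n",
    "        ],\n",
    "    },\n",
    "    {\n",
    "        \"role\": \"tool\",\n",
    "        \"tool_call_id\": \"call_example_1\",\n",
    "        \"content\": \"   SKU_Coded  Total_Qty_Sold 0    6200700         52262.0\",\n",
    "    },\n",
    "    {\"role\": \"assistant\", \"content\": \"The most popular SKU was 6200700.\", \"tool_calls\": None},\n",
    "]\n",
    "\n",
    "summary = process_messages(example_messages)\n",
    "\n",
    "# One tool call, its response attached, the final answer, and four messages in total\n",
    "assert summary[\"tool_calls\"] == [\"lookup_sales_data\"]\n",
    "assert summary[\"tool_responses\"][0][\"tool_response\"] == example_messages[2][\"content\"]\n",
    "assert summary[\"final_output\"] == example_messages[3][\"content\"]\n",
    "assert summary[\"path_length\"] == 4\n",
    "summary[\"tool_responses\"]"
   ]
  },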
  {
   "cell_type": "code",
   "execution_count": 73,
   "metadata": {},
   "outputs": [],
   "source": [
    "def run_agent_and_track_path_combined(example: Example) -> dict:\n",
    "    print(\"Starting main span with messages:\", example.input.get(\"question\"))\n",
    "    messages = [{\"role\": \"user\", \"content\": example.input.get(\"question\")}]\n",
    "    ret = run_agent_messages_combined(messages)\n",
    "    return process_messages(ret)\n",
    "\n",
    "\n",
    "def run_agent_messages_combined(messages):\n",
    "    print(\"Running agent with messages:\", messages)\n",
    "    if isinstance(messages, str):\n",
    "        messages = [{\"role\": \"user\", \"content\": messages}]\n",
    "        print(\"Converted string message to list format\")\n",
    "\n",
    "    # Check and add system prompt if needed\n",
    "    if not any(\n",
    "        isinstance(message, dict) and message.get(\"role\") == \"system\" for message in messages\n",
    "    ):\n",
    "        system_prompt = {\n",
    "            \"role\": \"system\",\n",
    "            \"content\": \"You are a helpful assistant that can answer questions about the Store Sales Price Elasticity Promotions dataset.\",\n",
    "        }\n",
    "        messages.insert(0, system_prompt)\n",
    "        print(\"Added system prompt to messages\")\n",
    "\n",
    "    while True:\n",
    "        # Router call span\n",
    "        print(\"Starting router\")\n",
    "\n",
    "        response = client.chat.completions.create(\n",
    "            model=model,\n",
    "            messages=messages,\n",
    "            tools=tools,\n",
    "        )\n",
    "\n",
    "        messages.append(response.choices[0].message.model_dump())\n",
    "        tool_calls = response.choices[0].message.tool_calls\n",
    "        print(\"Received response with tool calls:\", bool(tool_calls))\n",
    "\n",
    "        if tool_calls:\n",
    "            # Tool calls span\n",
    "            print(\"Processing tool calls\")\n",
    "            messages = handle_tool_calls(tool_calls, messages)\n",
    "        else:\n",
    "            print(\"No tool calls, returning final response\")\n",
    "            return messages"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "generate_sql_query(\"What was the most popular product SKU?\", store_sales_df.columns, \"sales\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "overall_experiment_questions = [\n",
    "    {\n",
    "        \"question\": \"What was the most popular product SKU?\",\n",
    "        \"sql_result\": \"   SKU_Coded  Total_Qty_Sold 0    6200700         52262.0\",\n",
    "    },\n",
    "    {\n",
    "        \"question\": \"What was the total revenue across all stores?\",\n",
    "        \"sql_result\": \"   Total_Revenue 0   1.327264e+07\",\n",
    "    },\n",
    "    {\n",
    "        \"question\": \"Which store had the highest sales volume?\",\n",
    "        \"sql_result\": \"   Store_Number  Total_Sales_Volume 0          2970             59322.0\",\n",
    "    },\n",
    "    {\n",
    "        \"question\": \"Create a bar chart showing total sales by store\",\n",
    "        \"sql_result\": \"    Store_Number    Total_Sales 0            880  420302.088397 1           1650  580443.007953 2           4180  272208.118542 3            550  229727.498752 4           1100  497509.528013 5           3300  619660.167018 6           3190  335035.018792 7           2970  836341.327191 8           3740  359729.808228 9           2530  324046.518720 10          4400   95745.620250 11          1210  508393.767785 12           330  370503.687331 13          2750  453664.808068 14          1980  242290.828499 15          1760  350747.617798 16          3410  410567.848126 17           990  378433.018639 18          4730  239711.708869 19          4070  322307.968330 20          3080  495458.238811 21          2090  309996.247965 22          1320  592832.067579 23          2640  308990.318559 24          1540  427777.427815 25          4840  389056.668316 26          2860  132320.519487 27          2420  406715.767402 28           770  292968.918642 29          3520  145701.079372 30           660  343594.978075 31          3630  405034.547846 32          2310  412579.388504 33          2200  361173.288199 34          1870  401070.997685\",\n",
    "    },\n",
    "    {\n",
    "        \"question\": \"What percentage of items were sold on promotion?\",\n",
    "        \"sql_result\": \"   Promotion_Percentage 0              0.625596\",\n",
    "    },\n",
    "    {\n",
    "        \"question\": \"What was the average transaction value?\",\n",
    "        \"sql_result\": \"   Average_Transaction_Value 0                  19.018132\",\n",
    "    },\n",
    "    {\n",
    "        \"question\": \"Create a line chart showing sales in 2021\",\n",
    "        \"sql_result\": \"  sale_month  total_quantity_sold  total_sales_value 0 2021-11-01              43056.0      499984.428193 1 2021-12-01              75724.0      910982.118423\",\n",
    "    },\n",
    "]\n",
    "\n",
    "# Generate a SQL query for each question\n",
    "for question in overall_experiment_questions:\n",
    "    question[\"sql_generated\"] = generate_sql_query(\n",
    "        question[\"question\"], store_sales_df.columns, \"sales\"\n",
    "    )\n",
    "\n",
    "print(overall_experiment_questions[6])\n",
    "\n",
    "# overall_experiment_df = pd.DataFrame(overall_experiment_questions)\n",
    "\n",
    "# dataset = px_client.upload_dataset(dataframe=overall_experiment_df, dataset_name=\"overall_experiment_questions_all\", input_keys=[\"question\"], output_keys=[\"sql_result\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reference: example output of the cell above, including the SQL generated for each question\n",
    "[\n",
    "    {\n",
    "        \"question\": \"What was the most popular product SKU?\",\n",
    "        \"sql_result\": \"   SKU_Coded  Total_Qty_Sold 0    6200700         52262.0\",\n",
    "        \"sql_generated\": \"```sql\\nSELECT SKU_Coded, SUM(Qty_Sold) AS Total_Qty_Sold\\nFROM sales\\nGROUP BY SKU_Coded\\nORDER BY Total_Qty_Sold DESC\\nLIMIT 1;\\n```\",\n",
    "    },\n",
    "    {\n",
    "        \"question\": \"What was the total revenue across all stores?\",\n",
    "        \"sql_result\": \"   Total_Revenue 0   1.327264e+07\",\n",
    "        \"sql_generated\": \"```sql\\nSELECT SUM(Total_Sale_Value) AS Total_Revenue\\nFROM sales;\\n```\",\n",
    "    },\n",
    "    {\n",
    "        \"question\": \"Which store had the highest sales volume?\",\n",
    "        \"sql_result\": \"   Store_Number  Total_Sales_Volume 0          2970             59322.0\",\n",
    "        \"sql_generated\": \"```sql\\nSELECT Store_Number, SUM(Total_Sale_Value) AS Total_Sales_Volume\\nFROM sales\\nGROUP BY Store_Number\\nORDER BY Total_Sales_Volume DESC\\nLIMIT 1;\\n```\",\n",
    "    },\n",
    "    {\n",
    "        \"question\": \"Create a bar chart showing total sales by store\",\n",
    "        \"sql_result\": \"    Store_Number    Total_Sales 0            880  420302.088397 1           1650  580443.007953 2           4180  272208.118542 3            550  229727.498752 4           1100  497509.528013 5           3300  619660.167018 6           3190  335035.018792 7           2970  836341.327191 8           3740  359729.808228 9           2530  324046.518720 10          4400   95745.620250 11          1210  508393.767785 12           330  370503.687331 13          2750  453664.808068 14          1980  242290.828499 15          1760  350747.617798 16          3410  410567.848126 17           990  378433.018639 18          4730  239711.708869 19          4070  322307.968330 20          3080  495458.238811 21          2090  309996.247965 22          1320  592832.067579 23          2640  308990.318559 24          1540  427777.427815 25          4840  389056.668316 26          2860  132320.519487 27          2420  406715.767402 28           770  292968.918642 29          3520  145701.079372 30           660  343594.978075 31          3630  405034.547846 32          2310  412579.388504 33          2200  361173.288199 34          1870  401070.997685\",\n",
    "        \"sql_generated\": \"```sql\\nSELECT Store_Number, SUM(Total_Sale_Value) AS Total_Sales\\nFROM sales\\nGROUP BY Store_Number;\\n```\",\n",
    "    },\n",
    "    {\n",
    "        \"question\": \"What percentage of items were sold on promotion?\",\n",
    "        \"sql_result\": \"   Promotion_Percentage 0              0.625596\",\n",
    "        \"sql_generated\": \"```sql\\nSELECT \\n    (SUM(CASE WHEN On_Promo = 'Yes' THEN 1 ELSE 0 END) * 100.0) / COUNT(*) AS Promotion_Percentage\\nFROM \\n    sales;\\n```\",\n",
    "    },\n",
    "    {\n",
    "        \"question\": \"What was the average transaction value?\",\n",
    "        \"sql_result\": \"   Average_Transaction_Value 0                  19.018132\",\n",
    "        \"sql_generated\": \"```sql\\nSELECT AVG(Total_Sale_Value) AS Average_Transaction_Value\\nFROM sales;\\n```\",\n",
    "    },\n",
    "    {\n",
    "        \"question\": \"Create a line chart showing sales in 2021\",\n",
    "        \"sql_result\": \"  sale_month  total_quantity_sold  total_sales_value 0 2021-11-01              43056.0      499984.428193 1 2021-12-01              75724.0      910982.118423\",\n",
    "        \"sql_generated\": \"```sql\\nSELECT MONTH(Sold_Date) AS Month, SUM(Total_Sale_Value) AS Total_Sales\\nFROM sales\\nWHERE YEAR(Sold_Date) = 2021\\nGROUP BY MONTH(Sold_Date)\\nORDER BY MONTH(Sold_Date);\\n```\",\n",
    "    },\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 65,
   "metadata": {},
   "outputs": [],
   "source": [
    "CLARITY_LLM_JUDGE_PROMPT = \"\"\"\n",
    "In this task, you will be presented with a query and an answer. Your objective is to evaluate the clarity\n",
    "of the answer in addressing the query. A clear response is one that is precise, coherent, and directly\n",
    "addresses the query without introducing unnecessary complexity or ambiguity. An unclear response is one\n",
    "that is vague, disorganized, or difficult to understand, even if it may be factually correct.\n",
    "\n",
    "Your response should be a single word: either \"clear\" or \"unclear,\" and it should not include any other\n",
    "text or characters. \"clear\" indicates that the answer is well-structured, easy to understand, and\n",
    "appropriately addresses the query. \"unclear\" indicates that the answer is ambiguous, poorly organized, or\n",
    "not effectively communicated. Please carefully consider the query and answer before determining your\n",
    "response.\n",
    "\n",
    "After analyzing the query and the answer, you must write a detailed explanation of your reasoning to\n",
    "justify why you chose either \"clear\" or \"unclear.\" Avoid stating the final label at the beginning of your\n",
    "explanation. Your reasoning should include specific points about how the answer does or does not meet the\n",
    "criteria for clarity.\n",
    "\n",
    "[BEGIN DATA]\n",
    "Query: {query}\n",
    "Answer: {response}\n",
    "[END DATA]\n",
    "Please analyze the data carefully and provide an explanation followed by your response.\n",
    "\n",
    "EXPLANATION: Provide your reasoning step by step, evaluating the clarity of the answer based on the query.\n",
    "LABEL: \"clear\" or \"unclear\"\n",
    "\"\"\"\n",
    "\n",
    "ENTITY_CORRECTNESS_LLM_JUDGE_PROMPT = \"\"\"\n",
    "In this task, you will be presented with a query and an answer. Your objective is to determine whether all\n",
    "the entities mentioned in the answer are correctly identified and accurately match those in the query. An\n",
    "entity refers to any specific person, place, organization, date, or other proper noun. Your evaluation\n",
    "should focus on whether the entities in the answer are correctly named and appropriately associated with\n",
    "the context in the query.\n",
    "\n",
    "Your response should be a single word: either \"correct\" or \"incorrect,\" and it should not include any\n",
    "other text or characters. \"correct\" indicates that all entities mentioned in the answer match those in the\n",
    "query and are properly identified. \"incorrect\" indicates that the answer contains errors or mismatches in\n",
    "the entities referenced compared to the query.\n",
    "\n",
    "After analyzing the query and the answer, you must write a detailed explanation of your reasoning to\n",
    "justify why you chose either \"correct\" or \"incorrect.\" Avoid stating the final label at the beginning of\n",
    "your explanation. Your reasoning should include specific points about how the entities in the answer do or\n",
    "do not match the entities in the query.\n",
    "\n",
    "[BEGIN DATA]\n",
    "Query: {query}\n",
    "Answer: {response}\n",
    "[END DATA]\n",
    "Please analyze the data carefully and provide an explanation followed by your response.\n",
    "\n",
    "EXPLANATION: Provide your reasoning step by step, evaluating whether the entities in the answer are\n",
    "correct and consistent with the query.\n",
    "LABEL: \"correct\" or \"incorrect\"\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [],
   "source": [
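    "# Replace the plain list of tool names in the template with the full JSON tool definitions.\n",
    "# Curly braces are swapped for quotes so the template does not treat them as placeholders.\n",
    "json_tool_call_prompt_template = TOOL_CALLING_PROMPT_TEMPLATE.replace(\n",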
    "    \"generate_visualization, lookup_sales_data, analyze_sales_data, run_python_code\",\n",
    "    json.dumps(tools).replace(\"{\", '\"').replace(\"}\", '\"'),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
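    "def function_calling_eval(input: dict, output: dict) -> float:\n",
    "    \"\"\"Return the fraction of tool calls the LLM judge labels correct (0 if no tools were called).\"\"\"\n",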
    "    function_calls = output.get(\"tool_calls\")\n",
    "    if function_calls:\n",
    "        eval_df = pd.DataFrame(\n",
    "            {\"question\": [input.get(\"question\")] * len(function_calls), \"tool_call\": function_calls}\n",
    "        )\n",
    "\n",
    "        tool_call_w_json_tool_defs_evaluator = create_classifier(\n",
    "            name=\"tool_call_w_json_tool_defs\",\n",
    "            llm=eval_model,\n",
    "            prompt_template=json_tool_call_prompt_template,\n",
    "            rails=[\"correct\", \"incorrect\"],\n",
    "        )\n",
    "\n",
    "        tool_call_eval = evaluate_dataframe(\n",
    "            dataframe=eval_df,\n",
    "            evaluators=[tool_call_w_json_tool_defs_evaluator],\n",
    "        )\n",
    "\n",
    "        tool_call_eval[\"score\"] = (tool_call_eval[\"label\"] == \"correct\").astype(int)\n",
    "        return tool_call_eval[\"score\"].mean()\n",
    "    else:\n",
    "        return 0\n",
    "\n",
    "\n",
    "def code_is_runnable(output: dict) -> bool:\n",
    "    \"\"\"Check that the code returned by generate_visualization runs (True if none was generated).\"\"\"\n",
    "    tool_responses = output.get(\"tool_responses\")\n",
    "    if not tool_responses:\n",
    "        return True\n",
    "\n",
    "    # Find the first generate_visualization response\n",
    "    visualization = next(\n",
    "        (r for r in tool_responses if r.get(\"tool_name\") == \"generate_visualization\"), None\n",
    "    )\n",
    "    if not visualization:\n",
    "        return True\n",
    "\n",
    "    # Strip whitespace and Markdown code fences from the generated code\n",
    "    generated_code = (visualization.get(\"tool_response\") or \"\").strip()\n",
    "    generated_code = generated_code.replace(\"```python\", \"\").replace(\"```\", \"\")\n",
    "    try:\n",
    "        # exec runs untrusted, model-generated code in the current kernel\n",
    "        exec(generated_code)\n",
    "        return True\n",
    "    except Exception:\n",
    "        return False\n",
    "\n",
    "\n",
    "def evaluate_sql_result(output: dict, expected: dict) -> bool:\n",
    "    \"\"\"Compare the first lookup_sales_data result with the expected SQL result.\"\"\"\n",
    "    tool_responses = output.get(\"tool_responses\")\n",
    "    if not tool_responses:\n",
    "        return True\n",
    "\n",
    "    # Find the first lookup_sales_data response\n",
    "    sql_result = next((r for r in tool_responses if r.get(\"tool_name\") == \"lookup_sales_data\"), None)\n",
    "    if not sql_result:\n",
    "        return True\n",
    "\n",
    "    sql_result = sql_result.get(\"tool_response\") or \"\"\n",
    "\n",
    "    # Keep only the digits, so spacing, punctuation and letters (such as column names) are ignored\n",
    "    result_nums = \"\".join(filter(str.isdigit, sql_result))\n",
    "    expected_nums = \"\".join(filter(str.isdigit, expected.get(\"sql_result\") or \"\"))\n",
    "    return result_nums == expected_nums\n",
    "\n",
    "\n",
    "def evaluate_clarity(output: dict, input: dict) -> bool:\n",
    "    df = pd.DataFrame({\"query\": [input.get(\"question\")], \"response\": [output.get(\"final_output\")]})\n",
    "\n",
    "    clarity_evaluator = create_classifier(\n",
    "        name=\"clarity\",\n",
    "        llm=eval_model,\n",
    "        prompt_template=CLARITY_LLM_JUDGE_PROMPT,\n",
    "        rails=[\"clear\", \"unclear\"],\n",
    "    )\n",
    "\n",
    "    response = evaluate_dataframe(\n",
    "        dataframe=df,\n",
    "        evaluators=[clarity_evaluator],\n",
    "    )\n",
    "\n",
    "    return bool(response[\"label\"].iloc[0] == \"clear\")\n",
    "\n",
    "\n",
    "def evaluate_entity_correctness(output: dict, input: dict) -> bool:\n",
    "    df = pd.DataFrame({\"query\": [input.get(\"question\")], \"response\": [output.get(\"final_output\")]})\n",
    "\n",
    "    entity_correctness_evaluator = create_classifier(\n",
    "        name=\"entity_correctness\",\n",
    "        llm=eval_model,\n",
    "        prompt_template=ENTITY_CORRECTNESS_LLM_JUDGE_PROMPT,\n",
    "        rails=[\"correct\", \"incorrect\"],\n",
    "    )\n",
    "\n",
    "    response = evaluate_dataframe(\n",
    "        dataframe=df,\n",
    "        evaluators=[entity_correctness_evaluator],\n",
    "    )\n",
    "\n",
    "    return bool(response[\"label\"].iloc[0] == \"correct\")"
   ]
  },
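  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`evaluate_sql_result` keeps only the digits of each result string before comparing them. The standalone sketch below repeats that filtering step in a hypothetical `digits_only` helper to show the trade-off: differences in spacing or column names are ignored, but the same number printed in a different format (scientific notation versus a plain float) no longer matches."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical helper that repeats the filtering step used in evaluate_sql_result\n",
    "def digits_only(text: str) -> str:\n",
    "    return \"\".join(filter(str.isdigit, text))\n",
    "\n",
    "\n",
    "expected_result = \"   Total_Revenue 0   1.327264e+07\"\n",
    "\n",
    "# Different spacing and column name, same digits: counted as a match\n",
    "assert digits_only(expected_result) == digits_only(\"Revenue 0 1.327264e+07\")\n",
    "\n",
    "# Same value written as a plain float: the digits differ, so it is counted as a mismatch\n",
    "assert digits_only(expected_result) != digits_only(\"   Total_Revenue 0   13272640.0\")\n",
    "\n",
    "print(digits_only(expected_result))  # 0132726407\n",
    "print(digits_only(\"   Total_Revenue 0   13272640.0\"))  # 0132726400"
   ]
  },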
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def run_overall_experiment(example: Example) -> dict:\n",
    "    with suppress_tracing():\n",
    "        return run_agent_and_track_path_combined(example)\n",
    "\n",
    "\n",
    "experiment = run_experiment(\n",
    "    dataset,\n",
    "    run_overall_experiment,\n",
    "    evaluators=[\n",
    "        function_calling_eval,\n",
    "        evaluate_sql_result,\n",
    "        evaluate_clarity,\n",
    "        evaluate_entity_correctness,\n",
    "        code_is_runnable,\n",
    "    ],\n",
    "    experiment_name=\"Overall Experiment\",\n",
    "    experiment_description=\"Evaluating the overall experiment\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Congratulations! 🎉\n",
    "\n",
    "You've now evaluated every aspect of your agent. If you've made it this far, you're now an expert in evaluating agent routers, tools, and paths!\n",
    "\n",
    "![Combined Agent Experiments](https://storage.googleapis.com/arize-phoenix-assets/assets/images/combined-agent-experiments.png)\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
